import torch
from accelerate import Accelerator
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup

from pruner.utils import find_layers

from .trainer import trainer


class sparsetrainer(trainer):
    """Trainer that fine-tunes a pruned model while preserving its sparsity pattern.

    After every backward pass the gradients of the pruned weights are zeroed
    (``mask_gradients``) before the optimizer step, so gradient updates do not
    revive them.

    NOTE(review): zero gradients only keep zero weights at zero if the optimizer
    carries no prior momentum for those entries; a freshly constructed optimizer
    (presumably what the base ``trainer`` creates) satisfies this — confirm.
    """

    def __init__(self, model, lr, num_warmup_steps, device):
        # All state (model, optimizer, num_warmup_steps, ...) is presumably set
        # up by the base ``trainer``; this subclass only changes the training loop.
        super().__init__(model, lr, num_warmup_steps, device)

    def _base_model(self):
        """Return ``self.model`` with any (Distributed)DataParallel wrappers removed."""
        model = self.model
        # In multi-process runs ``accelerator.prepare`` wraps the model in DDP,
        # which does not forward attribute access such as ``.model``; unwrap it
        # so the decoder layers can be located.
        while isinstance(
            model,
            (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
        ):
            model = model.module
        return model

    def mask_gradients(self, mask=None):
        """Zero the gradients of pruned weights in every decoder layer.

        Sublayers are visited in the order ``find_layers`` returns them for each
        decoder layer (presumably the prunable linear modules).

        Args:
            mask: Optional sequence with one entry per visited sublayer, in visit
                order. Gradient entries selected by ``mask[i]`` (a boolean or
                index tensor usable for indexing the weight gradient) are set to
                zero. When ``None``, the mask is derived on the fly from which
                weight entries are exactly zero.
        """
        inner = self._base_model().model
        # Two supported layouts: ``model.model.layers`` or
        # ``model.model.decoder.layers`` (presumably LLaMA-style vs OPT-style).
        layers = inner.layers if hasattr(inner, "layers") else inner.decoder.layers

        cnt = 0
        for layer in layers:
            subset = find_layers(layer)
            for name in subset:
                weight = subset[name].weight
                grad = weight.grad
                # A sublayer that received no gradient (frozen or unused in the
                # forward pass) has nothing to mask. Still advance ``cnt`` so an
                # explicit ``mask`` list stays aligned with the sublayers.
                if grad is not None:
                    if mask is not None:
                        grad[mask[cnt]] = 0
                    else:
                        grad[weight.data == 0] = 0
                cnt += 1

    def train(self, loader, epochs):
        """Fine-tune ``self.model`` on ``loader`` for ``epochs`` epochs, keeping sparsity.

        Args:
            loader: Iterable of batches; each batch is passed to the model as
                keyword arguments and the model output must expose ``.loss``.
            epochs: Number of passes over ``loader``.

        Returns:
            ``self.model`` as returned by ``accelerator.prepare`` (which may be a
            wrapped module in distributed runs).
        """
        accelerator = Accelerator()
        self.model.train()
        # The schedule horizon is computed from the *un-prepared* loader on
        # purpose. Accelerate's prepared scheduler is expected to step once per
        # process per call (non-split batches), which offsets the shorter
        # per-process loader after ``prepare``. Confirm against the installed
        # accelerate version.
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=len(loader) * epochs,
        )
        self.model, self.optimizer, loader, scheduler = accelerator.prepare(
            self.model, self.optimizer, loader, scheduler
        )

        # Per-process iteration count, used only for the progress display.
        total_steps = len(loader) * epochs
        with tqdm(
            range(total_steps),
            desc="Sparse Training",
            disable=not accelerator.is_local_main_process,
        ) as progress_bar:
            for _ in range(epochs):
                for batch in loader:
                    with accelerator.autocast():
                        outputs = self.model(**batch)
                        loss = outputs.loss
                    accelerator.backward(loss)
                    # Must run between backward and step so pruned weights get
                    # no update.
                    self.mask_gradients()
                    self.optimizer.step()
                    scheduler.step()
                    self.optimizer.zero_grad()
                    progress_bar.update(1)
                    progress_bar.set_postfix({"loss": loss.item()})

        return self.model
